{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# t-SNE Animations\n",
    "\n",
    "*openTSNE* includes a callback system, with can be triggered every *n* iterations and can also be used to control optimization and when to stop.\n",
    "\n",
    "In this notebook, we'll look at an example and use callbacks to generate an animation of the optimization. In practice, this serves no real purpose other than being fun to look at."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openTSNE\n",
    "from examples import utils\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "The preprocessed data set can be downloaded from http://file.biolab.si/opentsne/benchmark/macosko_2015.pkl.gz."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip\n",
    "import pickle\n",
    "\n",
    "with gzip.open(\"data/macosko_2015.pkl.gz\", \"rb\") as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "x = data[\"pca_50\"]\n",
    "y = data[\"CellType1\"].astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data set contains 44808 samples with 50 features\n"
     ]
    }
   ],
   "source": [
    "print(\"Data set contains %d samples with %d features\" % x.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We pass a callback that will take the current embedding, make a copy (this is important because the embedding is changed inplace during optimization) and add it to a list. We can also specify how often the callbacks should be called. In this instance, we'll call it at every iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = []\n",
    "\n",
    "tsne = openTSNE.TSNE(\n",
    "    perplexity=50, metric=\"cosine\", n_jobs=32, verbose=True,\n",
    "    # The embedding will be appended to the list we defined above, make sure we copy the\n",
    "    # embedding, otherwise the same object reference will be stored for every iteration\n",
    "    callbacks=lambda it, err, emb: embeddings.append(np.array(emb)),\n",
    "    # This should be done on every iteration\n",
    "    callbacks_every_iters=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "TSNE(callbacks=<function <lambda> at 0x7f61b025da60>, callbacks_every_iters=1,\n",
      "     early_exaggeration=12, metric='cosine', n_jobs=32, perplexity=50,\n",
      "     verbose=True)\n",
      "--------------------------------------------------------------------------------\n",
      "===> Finding 150 nearest neighbors using Annoy approximate search using cosine distance...\n",
      "   --> Time elapsed: 15.90 seconds\n",
      "===> Calculating affinity matrix...\n",
      "   --> Time elapsed: 1.29 seconds\n",
      "===> Calculating PCA-based initialization...\n",
      "   --> Time elapsed: 0.29 seconds\n",
      "===> Running optimization with exaggeration=12.00, lr=3734.00 for 250 iterations...\n",
      "Iteration   50, KL divergence 5.0623, 50 iterations in 5.0761 sec\n",
      "Iteration  100, KL divergence 4.9120, 50 iterations in 5.1648 sec\n",
      "Iteration  150, KL divergence 4.8728, 50 iterations in 5.1445 sec\n",
      "Iteration  200, KL divergence 4.8556, 50 iterations in 5.1387 sec\n",
      "Iteration  250, KL divergence 4.8436, 50 iterations in 5.1009 sec\n",
      "   --> Time elapsed: 25.63 seconds\n",
      "===> Running optimization with exaggeration=1.00, lr=44808.00 for 500 iterations...\n",
      "Iteration   50, KL divergence 2.9395, 50 iterations in 5.1579 sec\n",
      "Iteration  100, KL divergence 2.7535, 50 iterations in 6.3857 sec\n",
      "Iteration  150, KL divergence 2.6724, 50 iterations in 8.0214 sec\n",
      "Iteration  200, KL divergence 2.6245, 50 iterations in 9.6104 sec\n",
      "Iteration  250, KL divergence 2.5928, 50 iterations in 12.0700 sec\n",
      "Iteration  300, KL divergence 2.5686, 50 iterations in 13.6955 sec\n",
      "Iteration  350, KL divergence 2.5505, 50 iterations in 14.5318 sec\n",
      "Iteration  400, KL divergence 2.5363, 50 iterations in 15.3540 sec\n",
      "Iteration  450, KL divergence 2.5240, 50 iterations in 17.1487 sec\n",
      "Iteration  500, KL divergence 2.5143, 50 iterations in 18.1806 sec\n",
      "   --> Time elapsed: 120.16 seconds\n",
      "CPU times: user 7min 34s, sys: 13.2 s, total: 7min 47s\n",
      "Wall time: 2min 48s\n"
     ]
    }
   ],
   "source": [
    "%time tsne_embedding = tsne.fit(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have all the iterations in our list, we need to create the animation. We do this here using matplotlib, which is relatively straightforward. Generating the animation can take a long time, so we will save it as a mp4 video file so we can come back to it whenever we want, without having to wait again. Please make sure that `ffmpeg` is installed on your system before running this command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 12min 28s, sys: 3.3 s, total: 12min 31s\n",
      "Wall time: 12min 32s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "ax.set_xticks([]), ax.set_yticks([])\n",
    "\n",
    "colors = list(map(utils.MACOSKO_COLORS.get, y))\n",
    "pathcol = ax.scatter(embeddings[0][:, 0], embeddings[0][:, 1], c=colors, s=1, rasterized=True)\n",
    "\n",
    "def update(embedding, ax, pathcol):\n",
    "    # Update point positions\n",
    "    pathcol.set_offsets(embedding)\n",
    "    \n",
    "    # Adjust x/y limits so all the points are visible\n",
    "    ax.set_xlim(np.min(embedding[:, 0]), np.max(embedding[:, 0]))\n",
    "    ax.set_ylim(np.min(embedding[:, 1]), np.max(embedding[:, 1]))\n",
    "    \n",
    "    return [pathcol]\n",
    "\n",
    "anim = animation.FuncAnimation(\n",
    "    fig, update, fargs=(ax, pathcol), interval=20,\n",
    "    frames=embeddings, blit=True,\n",
    ")\n",
    "\n",
    "anim.save(\"macosko.mp4\", dpi=150, writer=\"ffmpeg\")\n",
    "plt.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
